import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))

import json
import random
from argparse import ArgumentParser
from image_synthesis.data.utils.tsv_file import TSVFile, CompositeTSVFile, tsv_writer
from image_synthesis.utils.io import save_config_to_yaml
from torch.utils.data.dataloader import DataLoader
from collections import defaultdict
import torch.utils.data as data
from PIL import Image
import sys
import torch
import numpy as np
from collections import OrderedDict



class TSVTextDataset(data.Dataset):
    """
        Map-style dataset that serves caption text stored in one or more TSV files.

        The class was written in the context of Image/Text pair data for contrastive
        learning described in the following paper,
        "Learning Transferable Visual Models From Natural Language Supervision" (a.k.a CLIP),
        but the code in this class only reads the text rows.

        Every row is read as ``[key, payload, ...]``: column 0 is the key and column 1 is
        either the raw text (``text_format != 'json'``) or a JSON object with a
        ``"captions"`` entry (``text_format == 'json'``).
    """
    def __init__(self,
                 name,
                 text_tsv_file,
                 data_root='',
                 num_captions=1,
                 text_format='txt'):
        """
        Args:
            name: sub-directory of ``data_root`` that holds the TSV files.
            text_tsv_file: a file name or a list of file names, relative to
                ``data_root/name``. A single ``.tsv`` is opened directly, a single
                ``.txt`` is handed to ``CompositeTSVFile`` as a listing file, and
                several names are handed to ``CompositeTSVFile`` as a list.
            data_root: root directory; the empty string selects ``'data'``.
            num_captions: ``1`` picks one random caption from a JSON caption list,
                any other value keeps at most that many leading captions.
            text_format: ``'json'`` to parse column 1 as JSON, anything else to use
                column 1 verbatim.

        Raises:
            ValueError: a single file name has an extension other than
                ``.tsv`` / ``.txt``.
        """
        self.name = name
        self.data_root = 'data' if data_root == '' else data_root

        # A bare string would otherwise be iterated character by character below.
        if isinstance(text_tsv_file, str):
            text_tsv_file = [text_tsv_file]
        tsv_paths = [os.path.join(self.data_root, name, tf) for tf in text_tsv_file]

        self._chunk_sizes = None
        self.num_captions = num_captions
        self.text_format = text_format

        if len(tsv_paths) == 1:
            single_path = tsv_paths[0]
            extension = os.path.splitext(single_path)[1].lower()
            if extension == '.tsv':
                # single tsv file
                self.text_tsv_file = TSVFile(single_path, if_generate_lineidx=True)
            elif extension == '.txt':
                # multiple tsv files specified in a text file
                self.text_tsv_file = CompositeTSVFile(single_path)
                self._chunk_sizes = self.text_tsv_file.get_chunk_size()
            else:
                raise ValueError("Invalid input! Please check the tsv filenames.")
        else:
            # multiple tsv files specified in a list
            self.text_tsv_file = CompositeTSVFile(tsv_paths)
            self._chunk_sizes = self.text_tsv_file.get_chunk_size()

    def get_chunk_sizes(self):
        """Return the chunk sizes reported by the composite TSV, or None for a single .tsv file."""
        return self._chunk_sizes

    def get_class_boundaries(self):
        """Delegate to ``get_class_boundaries()`` of the underlying TSV object."""
        # The samples of each class are organized class-by-class.
        # _class_boundaries stores the lower- and upper-bound of each class.
        return self.text_tsv_file.get_class_boundaries()

    def __getitem__(self, index):
        """
        Return the lower-cased text of one row.

        Args:
            index: a row index, or a tuple whose first element is the row index.
                The second tuple element used to select a file to pre-load; it is
                accepted for compatibility and ignored.

        Returns:
            A lower-cased string, or a list of lower-cased strings when the row
            decodes to several captions.
        """
        if isinstance(index, tuple):
            # The former pre-fetch step relied on `self.image_tsv_file`, `threading`
            # and `pre_fetch`, none of which this class or module provides, so any
            # non-negative hint crashed. Only the row index is used.
            index = index[0]
        return self._lowercase(self._load_one_data(index))

    @staticmethod
    def _lowercase(text):
        """Lower-case a single caption or every caption of a list."""
        if isinstance(text, str):
            return text.lower()
        return [caption.lower() for caption in text]

    def _load_one_data(self, index):
        """Read row ``index`` and return its decoded text without the key."""
        _, txt = self._decode_text(self.text_tsv_file[index])
        return txt

    def _decode_text(self, items):
        """
        Split one TSV row into ``(key, text)``.

        Args:
            items: row columns; ``items[0]`` is the key, ``items[1]`` the payload.

        Returns:
            ``(key, text)`` where ``text`` is a string, or a list of captions when
            the JSON payload holds a list and ``num_captions != 1``.

        Raises:
            ValueError: the JSON ``"captions"`` entry is neither a string nor a list.
        """
        key = items[0]

        if self.text_format != 'json':
            return key, items[1]

        js = json.loads(items[1])
        assert 'captions' in js, '"captions" does not in {}'.format(js)
        captions = js['captions']
        if isinstance(captions, list):
            if self.num_captions == 1:
                text = random.choice(captions)
            else:
                # Keep at most `num_captions` leading captions.
                text = captions[:self.num_captions]
        elif isinstance(captions, str):
            text = captions
        else:
            raise ValueError('captions should be str or list')

        return key, text

    def __len__(self):
        """Number of rows in the underlying TSV object."""
        return len(self.text_tsv_file)

def gen_rows(filename, indexes):
    """Lazily yield the rows of the TSV file `filename` at the positions in `indexes`.

    The file is opened with `TSVFile` when iteration starts; rows come out in
    the order given by `indexes`.
    """
    rows = TSVFile(filename)
    yield from (rows[position] for position in indexes)

# Command-line options. Every option is a string with the default listed here.
parser = ArgumentParser()
for _flag, _default in (
    ('--data_root', ''),
    ('--dataset_name', 'conceptualcaption'),
    ('--dataset_phase', 'train'),
    ('--image_tsv_file', 'gcc-train-image-00.tsv,gcc-train-image-01.tsv'),
    ('--text_tsv_file', 'gcc-train-text-00.tsv,gcc-train-text-01.tsv'),
):
    parser.add_argument(_flag, default=_default, type=str)
# Defaults previously used for the validation split:
#   --dataset_phase val  --image_tsv_file gcc-val-image.tsv  --text_tsv_file gcc-val-text.tsv
args = parser.parse_args()
                                             
# TARGET_WORDS starts out as a path. If that path is an existing file, it is
# replaced by the set of its non-empty lines (one word per line); otherwise the
# string itself becomes the only member of the set.
TARGET_WORDS = 'data/tools/conceptual_caption_choosed_words_2.txt'
if isinstance(TARGET_WORDS, str):
    if not os.path.isfile(TARGET_WORDS):
        TARGET_WORDS = {TARGET_WORDS}
    else:
        with open(TARGET_WORDS) as word_file:
            # Drop the newline of every line, then discard blank lines.
            TARGET_WORDS = {entry.replace('\n', '') for entry in word_file} - {''}

# Text-only dataset over <data_root>/<dataset_name>/<dataset_phase>/<each listed tsv>,
# with column 1 of every row parsed as JSON.
tsv_dataset = TSVTextDataset(
    name=os.path.join(args.dataset_name, args.dataset_phase),
    text_tsv_file=args.text_tsv_file.split(','),
    data_root=args.data_root,
    text_format='json',
)
# One caption per batch, in file order, so the enumeration index of the loader
# can be stored as the row index.
# NOTE(review): worker processes are requested from module level; platforms that
# start workers with "spawn" normally need an `if __name__ == '__main__':` guard
# around this -- confirm the target platform.
tsv_loader = DataLoader(
    dataset=tsv_dataset,
    batch_size=1,
    num_workers=2,
    shuffle=False,
)

# Occurrences of every word over the captions kept so far (each word counted
# once per kept caption, because captions are reduced to a set of words).
word_freq = defaultdict(int)
# Enumeration indexes (positions in tsv_loader) of the kept captions.
curr_remain_idxs = []

# A target word stops qualifying new captions once its tally in word_freq
# reaches this cap.
max_count_per_target = 28000 # 3000
# Only referenced by the disabled per-target counting variant mentioned below.
word_count = defaultdict(int)

# save filtered captions
filtered_captions = os.path.join('data/captions', 'filtered_{}_{}.txt'.format(args.dataset_name, args.dataset_phase))
os.makedirs(os.path.dirname(filtered_captions), exist_ok=True)
# The loader length does not change while iterating; compute it once.
num_batches = len(tsv_loader)
# The context manager closes the caption file even if the loop raises.
with open(filtered_captions, 'w') as filtered_captions_writer:
    print('searching good text idxs')
    for idx, text in enumerate(tsv_loader):
        if idx % 1000 == 0:
            print('%d/%d' % (idx, num_batches))
        # Batches are indexed with [0]: the loader is expected to deliver one
        # caption per batch.
        text = text[0].lower()
        words = set(text.split())

        overlap = words & TARGET_WORDS
        # A caption is kept when at least one of its target words is still
        # below the cap. (A disabled variant compared and incremented
        # word_count per target word instead of using word_freq.)
        # Every overlapping word is looked up on purpose, without stopping at
        # the first hit, so that keys enter word_freq in the same order as before.
        valid = False
        for w in overlap:
            if word_freq[w] < max_count_per_target:
                valid = True
        if valid:
            filtered_captions_writer.write(text+'\n')
            curr_remain_idxs.append(idx)
            for word in words:
                word_freq[word]+=1

        # Debugging aid that used to live here: stop once idx > 20000.

print('====> total %d / %d satisfied,  start writing tsv files' % (len(curr_remain_idxs), num_batches))
print('filtered captions saved to {}'.format(filtered_captions))

# save the word frequency of the kept captions, most frequent word first
vocab = list(word_freq.keys())
freqs = list(word_freq.values())
# argsort ranks ascending; walking the ranking backwards gives descending counts.
ranking = np.argsort(freqs)[::-1]
word_freq = OrderedDict((vocab[pos], freqs[pos]) for pos in ranking)

out_word_frequency = os.path.join(
    'data/tools/word_frequency',
    'filtered_{}_{}.yaml'.format(args.dataset_name, args.dataset_phase),
)
os.makedirs(os.path.dirname(out_word_frequency), exist_ok=True)
save_config_to_yaml(word_freq, out_word_frequency)
print('filtered word frequency saved to {}'.format(out_word_frequency))

# Write the kept indexes, one per line and without a trailing newline.
out_index = os.path.join(
    '.cache/filter_data',
    args.dataset_name,
    args.dataset_phase,
    'filtered_{}_index.txt'.format(args.dataset_phase),
)
os.makedirs(os.path.dirname(out_index), exist_ok=True)
with open(out_index, 'w') as index_file:
    index_file.write('\n'.join(map(str, curr_remain_idxs)))
print('filtered index saved to {}'.format(out_index))

# out_image_tsv_file = os.path.join(args.out_dir, 'filtered_%s_image_tsv_%03d.tsv' \
#                                                     % (args.dataset_name, file_idx))
# out_text_tsv_file = os.path.join(args.out_dir, 'filtered_%s_text_tsv_%03d.tsv' \
#                                                     % (args.dataset_name, file_idx))
# tsv_writer(gen_rows(os.path.join(args.data_root, args.dataset_name, img_tsv_list[0]), curr_remain_idxs), out_image_tsv_file)
# tsv_writer(gen_rows(os.path.join(args.data_root, args.dataset_name, text_tsv_list[0]), curr_remain_idxs), out_text_tsv_file)

# torch.save(word_freq, os.path.join(args.out_dir, 'word_freq_%03d.pth' % file_idx))






